package com.yql.biz.core.config;

import cn.hutool.json.JSONUtil;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.spring.boot.autoconfigure.DruidDataSourceBuilder;
import com.google.common.collect.Range;
import com.yql.common.utils.DateUtils;
import com.yql.framework.config.properties.DruidProperties;
import org.apache.shardingsphere.api.config.sharding.KeyGeneratorConfiguration;
import org.apache.shardingsphere.api.config.sharding.ShardingRuleConfiguration;
import org.apache.shardingsphere.api.config.sharding.TableRuleConfiguration;
import org.apache.shardingsphere.api.config.sharding.strategy.StandardShardingStrategyConfiguration;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.shardingjdbc.api.ShardingDataSourceFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import javax.sql.DataSource;
import java.sql.SQLException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.YearMonth;
import java.time.format.DateTimeFormatter;
import java.util.*;

/**
 * Sharding configuration for the {@code sys_order} logical table.
 *
 * <p>Registers the physical data source {@code order1DataSource} and wraps it in
 * {@code shardingDataSource}, which splits {@code sys_order} into {@code sys_order_0} and
 * {@code sys_order_1} by {@code order_id} modulo the table count. A second data source (for
 * database sharding) and a month-based table layout are kept below as commented-out templates.
 */
@Configuration
public class ShardingDataSourceConfig
{

    private static final Logger log = LoggerFactory.getLogger(ShardingDataSourceConfig.class);

    /**
     * First physical data source, bound from {@code spring.datasource.druid.order1.*}.
     *
     * <p>Only registered when {@code spring.datasource.druid.order1.enabled=true}.
     * NOTE(review): {@code shardingDataSource} injects this bean unconditionally, so the context
     * cannot start when that property is not {@code true}; confirm this is intended.
     *
     * @param druidProperties project Druid settings, presumably the shared pool options applied
     *                        on top of the bound connection properties
     * @return the data source returned by {@code druidProperties.dataSource(...)}
     */
    @Bean
    @ConfigurationProperties("spring.datasource.druid.order1")
    @ConditionalOnProperty(prefix = "spring.datasource.druid.order1", name = "enabled", havingValue = "true")
    public DataSource order1DataSource(DruidProperties druidProperties)
    {
        DruidDataSource dataSource = DruidDataSourceBuilder.create().build();
        return druidProperties.dataSource(dataSource);
    }

    /*
     * Second physical data source; only needed when sharding across databases as well as
     * tables. To enable it, uncomment this bean and the "order2" entry in shardingDataSource.
     *
    @Bean
    @ConfigurationProperties("spring.datasource.druid.order2")
    @ConditionalOnProperty(prefix = "spring.datasource.druid.order2", name = "enabled", havingValue = "true")
    public DataSource order2DataSource(DruidProperties druidProperties)
    {
        DruidDataSource dataSource = DruidDataSourceBuilder.create().build();
        return druidProperties.dataSource(dataSource);
    }*/

    /**
     * Sharding data source that splits {@code sys_order} by {@code order_id} modulo the number
     * of physical tables.
     *
     * @param order1DataSource the physical data source registered under the name {@code order1}
     * @return the sharding-aware data source built by {@code ShardingDataSourceFactory}
     * @throws SQLException if the sharding data source cannot be created
     */
    @Bean(name = "shardingDataSource")
    public DataSource shardingDataSource(@Qualifier("order1DataSource") DataSource order1DataSource) throws SQLException
    {
        Map<String, DataSource> dataSourceMap = new HashMap<>();
        dataSourceMap.put("order1", order1DataSource);
        //dataSourceMap.put("order2", order2DataSource);

        // 1. Data nodes: the inline expression expands to order1.sys_order_0 and order1.sys_order_1.
        TableRuleConfiguration orderTableRuleConfig = new TableRuleConfiguration("sys_order", "order1.sys_order_$->{0..1}");
        // 2. Database sharding strategy (only needed once a second data source is configured):
        //orderTableRuleConfig.setDatabaseShardingStrategyConfig(new InlineShardingStrategyConfiguration("user_id", "order$->{user_id % 2 + 1}"));
        // 3. Table sharding strategy: route on order_id modulo the table count. The precise
        //    algorithm handles = / IN; the range algorithm handles BETWEEN and other ranges.
        orderTableRuleConfig.setTableShardingStrategyConfig(new StandardShardingStrategyConfiguration("order_id", new OrderPreciseShardingAlgorithm(), new OrderRangeShardingAlgorithm()));
        // Distributed primary key: order_id is generated with SNOWFLAKE.
        // NOTE(review): no worker id is configured here; confirm generated IDs stay unique when
        // several application instances run.
        orderTableRuleConfig.setKeyGeneratorConfig(new KeyGeneratorConfiguration("SNOWFLAKE", "order_id"));
        // Sharding rules
        ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();
        shardingRuleConfig.getTableRuleConfigs().add(orderTableRuleConfig);
        // Build the sharding data source
        return ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig, getProperties());
    }

    /*
     * Alternative template: split sys_order into month tables (sys_order_yyyyMM) by add_time.
     * Every month listed in the data-node expression needs an existing physical table, e.g.
     * sys_order_202107. The sharding column can be changed to suit the business. To use it,
     * replace the modulo shardingDataSource bean above with this one.
     *
    @Bean(name = "shardingDataSource")
    public DataSource shardingDataSource(@Qualifier("order1DataSource") DataSource order1DataSource) throws SQLException
    {
        Map<String, DataSource> dataSourceMap = new HashMap<>();
        dataSourceMap.put("order1", order1DataSource);
        // 1. Data nodes: months 01-09 and 10-12 for years 2021..2099 (keep both year ranges equal).
        TableRuleConfiguration orderTableRuleConfig = new TableRuleConfiguration("sys_order", "order1.sys_order_$->{2021..2099}0$->{1..9},order1.sys_order_$->{2021..2099}1$->{0..2}");
        // 2. Database sharding strategy (only needed once a second data source is configured):
        //orderTableRuleConfig.setDatabaseShardingStrategyConfig(new InlineShardingStrategyConfiguration("user_id", "order$->{user_id % 2 + 1}"));
        // 3. Table sharding strategy: route by the year/month of add_time.
        orderTableRuleConfig.setTableShardingStrategyConfig(new StandardShardingStrategyConfiguration("add_time",new OrderPreciseShardingAlgorithmByDate(),new OrderRangeShardingAlgorithmByDate()));
        // Distributed primary key
        orderTableRuleConfig.setKeyGeneratorConfig(new KeyGeneratorConfiguration("SNOWFLAKE", "order_id"));
        // Sharding rules
        ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();
        shardingRuleConfig.getTableRuleConfigs().add(orderTableRuleConfig);
        // Build the sharding data source
        DataSource dataSource = ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig, getProperties());
        return dataSource;
    }*/

    /**
     * Builds the ShardingSphere runtime properties.
     *
     * @return properties with {@code sql.show} set to {@code "true"}
     */
    private Properties getProperties()
    {
        Properties shardingProperties = new Properties();
        // setProperty keeps the value a String: an entry stored via put(key, Boolean) is
        // invisible to Properties#getProperty and #stringPropertyNames.
        shardingProperties.setProperty("sql.show", "true");
        return shardingProperties;
    }
}

/**
 * Precise sharding algorithm (handles {@code =} and {@code IN}) for a {@code Long} key.
 *
 * <p>Routes a value to the table whose trailing number equals
 * {@code floorMod(value, collection.size())}. For example, with {@code sys_order_0} and
 * {@code sys_order_1}, the value 7 routes to {@code sys_order_1}. It assumes the available
 * tables are numbered {@code 0..n-1}.
 */
class OrderPreciseShardingAlgorithm implements PreciseShardingAlgorithm<Long> {
    private static final Logger logger = LoggerFactory.getLogger(OrderPreciseShardingAlgorithm.class);

    /**
     * @param collection           available physical table names
     * @param preciseShardingValue the sharding-column value taken from the SQL condition
     * @return the single table the value routes to
     * @throws UnsupportedOperationException if no available table carries the expected number
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<Long> preciseShardingValue) {
        logger.info("collection:" + JSONUtil.toJsonStr(collection) + ",preciseShardingValue:" + JSONUtil.toJsonStr(preciseShardingValue));
        if (!collection.isEmpty()) {
            // floorMod keeps the index non-negative for negative keys (plain % would give "-1").
            String suffix = String.valueOf(Math.floorMod(preciseShardingValue.getValue(), (long) collection.size()));
            for (String name : collection) {
                if (hasNumericSuffix(name, suffix)) {
                    logger.info("return name:" + name);
                    return name;
                }
            }
        }
        // Returning null would only surface later as an obscure failure inside the router.
        throw new UnsupportedOperationException("No table for sharding value " + preciseShardingValue.getValue() + " in " + collection);
    }

    /**
     * Checks that {@code name} ends with {@code suffix} and that the suffix is its whole
     * trailing number, so {@code sys_order_11} does not match suffix {@code 1}.
     */
    private static boolean hasNumericSuffix(String name, String suffix) {
        if (!name.endsWith(suffix)) {
            return false;
        }
        int start = name.length() - suffix.length();
        return start == 0 || !Character.isDigit(name.charAt(start - 1));
    }
}

/**
 * Range sharding algorithm (handles {@code BETWEEN ... AND ...} and other ranges) for a
 * {@code Long} key, paired with {@link OrderPreciseShardingAlgorithm}.
 *
 * <p>Returns every table whose trailing number equals {@code floorMod(v, collection.size())}
 * for some {@code v} in the range. Endpoints are always treated as inclusive, which can add a
 * table but never miss one. Open-ended ranges route to every available table.
 */
class OrderRangeShardingAlgorithm implements RangeShardingAlgorithm<Long> {
    private static final Logger logger = LoggerFactory.getLogger(OrderRangeShardingAlgorithm.class);

    /**
     * @param collection         available physical table names
     * @param rangeShardingValue the range of sharding-column values from the SQL condition
     * @return the distinct tables that may hold rows in the range, in first-hit order
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<Long> rangeShardingValue) {
        logger.info("Range collection:" + JSONUtil.toJsonStr(collection) + ",rangeShardingValue:" + JSONUtil.toJsonStr(rangeShardingValue));
        Range<Long> valueRange = rangeShardingValue.getValueRange();
        // Conditions like "order_id > ?" have one endpoint; lowerEndpoint()/upperEndpoint() would
        // throw, and any table may hold matching rows.
        if (!valueRange.hasLowerBound() || !valueRange.hasUpperBound()) {
            return new ArrayList<>(collection);
        }
        int tableCount = collection.size();
        long lower = valueRange.lowerEndpoint();
        long upper = valueRange.upperEndpoint();
        // A set, so a table hit by several values is returned once.
        Collection<String> collect = new LinkedHashSet<>();
        // Any tableCount consecutive values already cover every residue, so walking further is
        // wasted work (and walking the whole range never finishes for snowflake-sized ranges).
        for (int step = 0; step < tableCount; step++) {
            long value = lower + step;
            String suffix = String.valueOf(Math.floorMod(value, (long) tableCount));
            for (String each : collection) {
                if (hasNumericSuffix(each, suffix)) {
                    collect.add(each);
                }
            }
            // Stop at the upper endpoint; since lower <= upper, value never passes upper or overflows.
            if (value == upper) {
                break;
            }
        }
        return collect;
    }

    /**
     * Checks that {@code name} ends with {@code suffix} and that the suffix is its whole
     * trailing number, so {@code sys_order_11} does not match suffix {@code 1}.
     */
    private static boolean hasNumericSuffix(String name, String suffix) {
        if (!name.endsWith(suffix)) {
            return false;
        }
        int start = name.length() - suffix.length();
        return start == 0 || !Character.isDigit(name.charAt(start - 1));
    }
}

/**
 * Precise sharding algorithm that routes a row to the month table of its sharding value.
 *
 * <p>The value is formatted with {@code DateUtils.secondToDateStr(value, "yyyyMM")} (presumably
 * epoch seconds, per the helper's name) and appended to the logical table name, e.g.
 * {@code sys_order} + {@code 202107} gives {@code sys_order_202107}. The available table
 * list is not consulted.
 */
class OrderPreciseShardingAlgorithmByDate implements PreciseShardingAlgorithm<Long>{

    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<Long> preciseShardingValue) {
        // Month suffix in yyyyMM form, computed from the sharding value.
        final String monthSuffix = DateUtils.secondToDateStr(preciseShardingValue.getValue(), "yyyyMM");
        // Physical names are built as "<logic table>_<yyyyMM>"; adjust this if the real tables
        // use a different prefix than the logical table name.
        final String logicTable = preciseShardingValue.getLogicTableName();
        return logicTable + "_" + monthSuffix;
    }
}

/**
 * Range sharding algorithm for month tables, paired with {@link OrderPreciseShardingAlgorithmByDate}.
 *
 * <p>Both endpoints are converted to {@code yyyyMM} with the same
 * {@code DateUtils.secondToDateStr} call the precise algorithm uses. Every month between them
 * (inclusive) is mapped to {@code <logic table>_<yyyyMM>}. Only names present in the available
 * table list are returned, so a range reaching outside the configured months does not route to
 * tables that do not exist. Open-ended ranges route to every available table.
 */
class OrderRangeShardingAlgorithmByDate implements RangeShardingAlgorithm<Long> {
    private static final Logger logger = LoggerFactory.getLogger(OrderRangeShardingAlgorithmByDate.class);

    /** Format of the month suffix; must agree with the "yyyyMM" passed to DateUtils. */
    private static final DateTimeFormatter MONTH_FORMAT = DateTimeFormatter.ofPattern("uuuuMM");

    /**
     * @param collection         available physical table names
     * @param rangeShardingValue the range of sharding-column values from the SQL condition
     * @return the available month tables covering the range, in chronological order
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<Long> rangeShardingValue) {
        logger.info("Range collection:" + JSONUtil.toJsonStr(collection) + ",rangeShardingValue:" + JSONUtil.toJsonStr(rangeShardingValue));
        Range<Long> valueRange = rangeShardingValue.getValueRange();
        // Conditions like "add_time > ?" have one endpoint; lowerEndpoint()/upperEndpoint() would
        // throw IllegalStateException, and any table may hold matching rows.
        if (!valueRange.hasLowerBound() || !valueRange.hasUpperBound()) {
            return new ArrayList<>(collection);
        }
        String minDate = DateUtils.secondToDateStr(valueRange.lowerEndpoint(), "yyyyMM");
        String maxDate = DateUtils.secondToDateStr(valueRange.upperEndpoint(), "yyyyMM");

        // Case-insensitive lookup of the configured table names, returning their configured spelling.
        Map<String, String> available = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (String name : collection) {
            available.put(name, name);
        }
        String prefix = rangeShardingValue.getLogicTableName() + "_";
        Set<String> result = new LinkedHashSet<>();
        for (String month : getMonth(minDate, maxDate)) {
            String tableName = available.get(prefix + month);
            if (tableName != null) {
                result.add(tableName);
            }
        }
        return result;
    }

    /**
     * Lists every month from {@code startDate} to {@code endDate} inclusive, in chronological order.
     *
     * @param startDate first month, formatted {@code yyyyMM}
     * @param endDate   last month, formatted {@code yyyyMM}
     * @return the months as {@code yyyyMM} strings; empty when {@code startDate} is after {@code endDate}
     * @throws java.time.format.DateTimeParseException if either argument is not {@code yyyyMM}
     */
    private static List<String> getMonth(String startDate, String endDate) {
        YearMonth last = YearMonth.parse(endDate, MONTH_FORMAT);
        List<String> months = new ArrayList<>();
        for (YearMonth month = YearMonth.parse(startDate, MONTH_FORMAT); !month.isAfter(last); month = month.plusMonths(1)) {
            months.add(month.format(MONTH_FORMAT));
        }
        return months;
    }

}